# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import copy
import os
from typing import Any, Callable, Dict, Optional, Tuple

from tensorrt_llm.llmapi.disagg_utils import (
    ConditionalDisaggConfig,
    DisaggClusterConfig,
    DisaggServerConfig,
    MetadataServerConfig,
    ServerRole,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType
from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterManager, WorkerInfo
from tensorrt_llm.serve.metadata_server import JsonDictionary
from tensorrt_llm.serve.openai_client import OpenAIClient
from tensorrt_llm.serve.openai_protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    DisaggregatedParams,
    UCompletionRequest,
    UCompletionResponse,
)
from tensorrt_llm.serve.openai_service import OpenAIService
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
from tensorrt_llm.serve.responses_utils import (
    ResponseHooks,
    UCompletionResponseOrGenerator,
    done_generator,
)
from tensorrt_llm.serve.router import KvCacheAwareRouter, Router


class OpenAIDisaggregatedService(OpenAIService):
    """OpenAI-compatible service that splits requests between context and generation servers.

    A request is normally sent first as a ``context_only`` request to a context server
    (chosen through ``ctx_router``). The ``disaggregated_params`` returned by that server
    are then attached to a ``generation_only`` request sent to a generation server (chosen
    through ``gen_router``). The context phase is skipped in two cases:

    - conditional disaggregation decides the generation server should prefill locally;
    - the ``TRTLLM_DISAGG_BENCHMARK_GEN_ONLY`` benchmark mode is enabled.

    Worker discovery uses a ``DisaggClusterManager`` when both a disagg cluster config and
    cluster storage are available. Otherwise routers may monitor servers via a metadata
    server, and ``setup`` waits for all configured servers to become ready.
    """

    def __init__(
        self,
        config: DisaggServerConfig,
        ctx_router: Router,
        gen_router: Router,
        client_factory: Callable[[Router, ServerRole, int], OpenAIClient],
        metadata_server: Optional[JsonDictionary] = None,
        metadata_config: Optional[MetadataServerConfig] = None,
        req_timeout_secs: int = 180,
        server_start_timeout_secs: int = 180,
        perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None,
        disagg_cluster_storage: Optional[ClusterStorage] = None,
        health_check_interval_secs: int = 3,
    ):
        """Store configuration; clients and the cluster manager are created in ``setup``.

        Args:
            config: Server config; ``max_retries``, ``disagg_cluster_config`` and
                ``conditional_disagg_config`` are read from it.
            ctx_router: Router over the context servers.
            gen_router: Router over the generation servers.
            client_factory: Called as ``client_factory(router, role, max_retries)`` to build
                the context and generation clients.
            metadata_server: Together with ``metadata_config``, enables router server
                monitoring when no cluster manager is used.
            metadata_config: Supplies ``refresh_interval`` for router server monitoring.
            req_timeout_secs: Stored on the instance.
            server_start_timeout_secs: Deadline for all servers to become ready in ``setup``.
            perf_metrics_collector: Stored on the instance.
            disagg_cluster_storage: Storage backing the ``DisaggClusterManager``.
            health_check_interval_secs: Delay between readiness polls during startup.
        """
        self._config = config
        self._ctx_router = ctx_router
        self._gen_router = gen_router
        self._client_factory = client_factory
        self._metadata_server = metadata_server
        self._metadata_config = metadata_config
        self._req_timeout_secs = req_timeout_secs
        self._server_start_timeout_secs = server_start_timeout_secs
        self._perf_metrics_collector = perf_metrics_collector
        self._cluster_storage = disagg_cluster_storage
        self._health_check_interval_secs = health_check_interval_secs

        self._ctx_client = None
        self._gen_client = None
        self._disagg_cluster_manager = None
        # Set by setup() when router server monitoring is started, so that teardown()
        # only stops monitoring that was actually started.
        self._server_monitoring_started = False

    async def openai_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        """Serve a completion request through the disaggregated pipeline.

        Only a single string prompt or a single list of token ids is supported. A batch
        of one wrapping either form (``["text"]`` or ``[[1, 2, 3]]``) is unwrapped first.

        Raises:
            RuntimeError: If the cluster is not ready.
            ValueError: If the prompt has an unsupported shape.
        """
        if not await self.is_ready():
            raise RuntimeError("Cluster is not ready")
        prompt = request.prompt
        # Unwrap a batch of one. Only str / token-list elements are unwrapped: a
        # single-token prompt such as [5] must stay a list, not become the bare int 5.
        if type(prompt) is list and len(prompt) == 1 and isinstance(prompt[0], (str, list)):
            request.prompt = prompt[0]
            prompt = request.prompt
        # Validate after unwrapping so an unwrapped inner list is checked as well.
        is_token_ids = isinstance(prompt, list) and all(isinstance(x, int) for x in prompt)
        if not (isinstance(prompt, str) or is_token_ids):
            raise ValueError(
                "Disaggregated server currently only supports single string prompt or list of integers in request"
            )

        return await self._send_disagg_request(request, hooks)

    async def openai_chat_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        """Serve a chat completion request through the disaggregated pipeline.

        Raises:
            RuntimeError: If the cluster is not ready.
        """
        if not await self.is_ready():
            raise RuntimeError("Cluster is not ready")
        return await self._send_disagg_request(request, hooks)

    async def _send_disagg_request(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        """Run the context phase (when needed), then the generation phase (when needed).

        Returns the generation client's result. If the context phase already finished
        the request (see ``_need_gen``), the context response is returned instead, or a
        finished generator for streaming requests.
        """
        if hooks:
            hooks.on_req_begin(request)
        # A server of None means the client (via its router) decides which server to use.
        # No context server is ever reserved here; a generation server is reserved only
        # under conditional disaggregation.
        reserved_ctx_server = None
        reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
        # Gen-only benchmark mode attaches fake disagg params and skips the context phase.
        need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
        ctx_response = None
        gen_req = request
        if need_ctx:
            ctx_req = self._get_ctx_request(request)
            # The context request is sent with stream=False (see _get_ctx_request).
            ctx_response = await self._ctx_client.send_request(
                ctx_req, server=reserved_ctx_server, hooks=hooks
            )
            await self._verify_ctx_response(ctx_response)
            gen_req = self._get_gen_request(request, ctx_response)
        if ctx_response is None or self._need_gen(ctx_response):
            return await self._gen_client.send_request(
                gen_req, server=reserved_gen_server, hooks=hooks
            )
        if request.stream:
            # The ctx client never returns a generator here because the context request
            # was non-streaming; a streaming caller gets an already-finished generator.
            # NOTE(review): the context response's content is not forwarded in this case.
            return done_generator()
        return ctx_response

    def _need_gen(self, response: UCompletionResponse) -> bool:
        """Return whether the generation phase must run after the context phase.

        Generation is needed unless the first choice of ``response`` finished with a
        reason other than ``"length"`` or ``"not_finished"``. When it is not needed, that
        choice's ``disaggregated_params`` is deleted, because ``_send_disagg_request``
        then returns the context response to the caller.
        """
        if response and response.choices[0].finish_reason not in ["length", "not_finished"]:
            del response.choices[0].disaggregated_params
            return False
        return True

    def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
        """Return a deep copy of ``request`` as a non-streaming ``context_only`` request."""
        ctx_request = copy.deepcopy(request)
        ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
        ctx_request.stream = False
        ctx_request.stream_options = None
        return ctx_request

    def _get_gen_request(
        self,
        request: UCompletionRequest,
        ctx_response: UCompletionResponse,
    ) -> UCompletionRequest:
        """Turn ``request`` into the ``generation_only`` request, in place, and return it.

        The context response's ``disaggregated_params`` are attached, and the original
        prompt is replaced with the context server's ``prompt_token_ids``:

        - for completion requests, they become ``prompt``;
        - for chat requests, they are set as ``prompt_token_ids``.
        """
        request.disaggregated_params = ctx_response.choices[0].disaggregated_params
        request.disaggregated_params.request_type = "generation_only"
        # Replace the string prompt with prompt_tokens_ids
        if isinstance(request, CompletionRequest):
            request.prompt = ctx_response.prompt_token_ids
        elif isinstance(request, ChatCompletionRequest):
            request.prompt_token_ids = ctx_response.prompt_token_ids
        return request

    async def _check_conditional_disagg(
        self, request: UCompletionRequest
    ) -> Tuple[Optional[str], bool]:
        """Decide whether a context phase is needed, reserving a generation server if applicable.

        Returns:
            A ``(gen_server, need_ctx)`` pair.

            - Without a conditional disagg config, this is ``(None, True)``.
            - Otherwise, ``gen_server`` is the server chosen by the KV-cache-aware
              generation router and is reserved for the generation request.
            - ``need_ctx`` is True when nothing matched in the KV cache, or when the
              unmatched token count exceeds ``max_local_prefill_length``.
        """
        if not self.conditional_disagg_config:
            return None, True
        assert isinstance(self._gen_router, KvCacheAwareRouter)
        # Query kv cache status and select a best gen_server.
        gen_server, info = await self._gen_router.get_next_server(request)
        # "matches" is summed as matched token counts; "token_lists" gives the total
        # number of request tokens.
        match_length = sum(info["matches"])
        total_length = sum(len(token_list) for token_list in info["token_lists"])
        unmatched_length = total_length - match_length
        need_ctx = (
            match_length == 0
            or unmatched_length > self.conditional_disagg_config.max_local_prefill_length
        )
        return gen_server, need_ctx

    async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool:
        """Apply gen-only benchmark mode, returning True when the context phase is skipped.

        When the env var ``TRTLLM_DISAGG_BENCHMARK_GEN_ONLY`` equals ``"1"``, the request
        is mutated in place to carry hard-coded ``generation_only`` params and to ignore EOS.
        """
        if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1":
            return False
        # Hard-code first token, ctx_request_id for testing
        request.disaggregated_params = DisaggregatedParams(
            request_type="generation_only",
            first_gen_tokens=[7],
            ctx_request_id=1,
            encoded_opaque_state=None,
            draft_tokens=None,
        )
        request.ignore_eos = True
        return True

    async def cluster_info(self) -> Dict[str, Any]:
        """Return readiness plus, when present, the cluster manager's own info."""
        cluster_info = {"is_ready": await self.is_ready()}
        if self._disagg_cluster_manager:
            cluster_info.update(await self._disagg_cluster_manager.cluster_info())
        return cluster_info

    async def is_ready(self) -> bool:
        """Delegate readiness to the cluster manager; always True without one."""
        if self._disagg_cluster_manager:
            return await self._disagg_cluster_manager.is_ready()
        return True

    @property
    def disagg_cluster_config(self) -> Optional[DisaggClusterConfig]:
        """The ``disagg_cluster_config`` section of the server config."""
        return self._config.disagg_cluster_config

    @property
    def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]:
        """The ``conditional_disagg_config`` section of the server config."""
        return self._config.conditional_disagg_config

    async def setup(self) -> None:
        """Create the context/generation clients and start worker discovery.

        With both a disagg cluster config and cluster storage, a ``DisaggClusterManager``
        is started, and worker events are forwarded to ``_on_worker_event``. Otherwise:

        - router server monitoring is started if a metadata server and config are given;
        - this method then waits until every context and generation server is ready.

        Raises:
            TimeoutError: If servers are not ready within ``server_start_timeout_secs``.
        """
        self._ctx_client = self._client_factory(
            self._ctx_router, ServerRole.CONTEXT, self._config.max_retries
        )
        self._gen_client = self._client_factory(
            self._gen_router, ServerRole.GENERATION, self._config.max_retries
        )

        if self.disagg_cluster_config and self._cluster_storage:
            logger.info("Starting disagg cluster manager")
            self._disagg_cluster_manager = DisaggClusterManager(
                self.disagg_cluster_config, self._cluster_storage
            )
            await self._disagg_cluster_manager.start()
            await self._disagg_cluster_manager.watch_workers(on_event=self._on_worker_event)
            logger.info("Disagg cluster manager started")
            return

        if self._metadata_server and self._metadata_config:
            logger.info("Starting server monitoring via metadata service")
            # Mark before starting so teardown() also cleans up after a partial start.
            self._server_monitoring_started = True
            await self._ctx_router.start_server_monitoring(self._metadata_config.refresh_interval)
            await self._gen_router.start_server_monitoring(self._metadata_config.refresh_interval)
        await self._wait_for_all_servers_ready()

    async def teardown(self) -> None:
        """Release what ``setup`` created: clients, cluster manager and router monitoring."""
        # Clients are None if setup() never ran; skip them instead of raising AttributeError.
        if self._ctx_client is not None:
            await self._ctx_client.shutdown()
        if self._gen_client is not None:
            await self._gen_client.shutdown()

        if self._disagg_cluster_manager:
            await self._disagg_cluster_manager.stop()

        # Only stop monitoring that setup() started (it is skipped on the cluster-manager
        # path and when no metadata config is given).
        if self._server_monitoring_started:
            await self._ctx_router.stop_server_monitoring()
            await self._gen_router.stop_server_monitoring()
            self._server_monitoring_started = False

    async def _wait_for_all_servers_ready(self) -> None:
        """Poll the context and generation clients until no server is reported unready.

        Raises:
            TimeoutError: If not all servers are ready within ``server_start_timeout_secs``.
        """

        async def poll_until_ready() -> None:
            # Loop until ready and let asyncio.wait_for enforce the deadline. Bounding the
            # loop here as well could let it end first, so the method returned normally
            # without any server being ready.
            while True:
                _, unready_ctx_servers = await self._ctx_client.check_ready()
                _, unready_gen_servers = await self._gen_client.check_ready()
                if len(unready_ctx_servers) == 0 and len(unready_gen_servers) == 0:
                    logger.info("All servers are ready")
                    return
                logger.info(
                    f"Waiting for servers, context: {unready_ctx_servers}, generation: {unready_gen_servers}"
                )
                await asyncio.sleep(self._health_check_interval_secs)

        try:
            await asyncio.wait_for(poll_until_ready(), timeout=self._server_start_timeout_secs)
        except asyncio.TimeoutError as err:
            raise TimeoutError(
                "Timeout waiting for context and generation servers to be ready"
            ) from err

    async def _on_worker_event(
        self, worker_info: WorkerInfo, event_type: WatchEventType
    ) -> None:
        """Add (SET) or remove (DELETE) a worker in the router matching its role.

        Unknown roles, and KeyErrors raised by the router operation, are logged rather
        than raised, so a single bad event is handled best-effort.
        """
        router_map = {ServerRole.CONTEXT: self._ctx_router, ServerRole.GENERATION: self._gen_router}
        worker_addr = f"{worker_info.host}:{worker_info.port}"
        # Look the role up separately so that a KeyError raised inside add_server /
        # remove_server is not misreported as an unknown role.
        router = router_map.get(worker_info.role)
        if router is None:
            logger.error(
                f"Unknown worker role: {worker_info.role}, Worker {worker_info.worker_id} event: {event_type.name}"
            )
            return
        try:
            if event_type == WatchEventType.SET:
                await router.add_server(worker_addr)
            elif event_type == WatchEventType.DELETE:
                await router.remove_server(worker_addr)
        except KeyError as err:
            logger.error(
                f"Failed to apply worker {event_type.name} event for {worker_info.worker_id} ({worker_addr}): {err!r}"
            )
            return
        logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}")

    async def _verify_ctx_response(
        self, ctx_response: Optional[UCompletionResponse]
    ) -> UCompletionResponse:
        """Validate the context server's response before it is used for generation.

        Returns ``ctx_response`` unchanged, for callers that chain on it.

        Raises:
            ValueError: If the response is missing, or does not have exactly one choice.
                Also raised if that choice lacks ``disaggregated_params`` or a
                ``ctx_request_id``.
        """
        # Without this check a missing response would surface later as an AttributeError
        # in _get_gen_request.
        if ctx_response is None:
            raise ValueError("Context server returned no response")
        if len(ctx_response.choices) != 1:
            raise ValueError(
                f"Context server returned {len(ctx_response.choices)} choices, expecting 1."
            )
        disagg_params = ctx_response.choices[0].disaggregated_params
        if disagg_params is None:
            raise ValueError("Context server did not return disaggregated params")
        if disagg_params.ctx_request_id is None:
            raise ValueError("Invalid disaggregated params in context phase response.")
        return ctx_response
